// Copyright 2021 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package spanconfigsqlwatcher

import (
	"context"
	"sort"

	"github.com/cockroachdb/cockroach/pkg/kv/kvclient/rangefeed/rangefeedbuffer"
	"github.com/cockroachdb/cockroach/pkg/spanconfig"
	"github.com/cockroachdb/cockroach/pkg/sql/catalog"
	"github.com/cockroachdb/cockroach/pkg/util/hlc"
	"github.com/cockroachdb/cockroach/pkg/util/log"
	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
)

// buffer is a helper struct for the SQLWatcher. It buffers events generated by
// the SQLWatcher's rangefeeds over system.zones, system.descriptors and
// system.protected_ts_records. It is safe for concurrent use; all state lives
// behind mu.
//
// The buffer also tracks a frontier timestamp for each of these rangefeeds. It
// maintains the notion of a combined frontier timestamp, computed as the
// minimum of the individual rangefeed frontiers. This is used when flushing
// the buffer.
type buffer struct {
	mu struct {
		syncutil.Mutex

		// buffer is the underlying rangefeedbuffer.Buffer that stores the
		// events added to this buffer.
		buffer *rangefeedbuffer.Buffer[event]

		// rangefeedFrontiers tracks the frontier timestamps of individual
		// rangefeeds established by the SQLWatcher. It is indexed by
		// rangefeedKind and has one slot per kind (numRangefeeds in total).
		rangefeedFrontiers [numRangefeeds]hlc.Timestamp
	}
}

// event is the unit produced by the rangefeeds the SQLWatcher establishes over
// system.protected_ts_records, system.zones and system.descriptors. It
// implements the rangefeedbuffer.Event interface.
type event struct {
	// timestamp at which the event was generated by the rangefeed.
	timestamp hlc.Timestamp

	// update captures information about the descriptor, zone, or protected
	// timestamp record that the SQLWatcher has observed change.
	update spanconfig.SQLUpdate
}

// Timestamp implements the rangefeedbuffer.Event interface. It returns the
// timestamp recorded on the event.
func (e event) Timestamp() hlc.Timestamp {
	return e.timestamp
}

// rangefeedKind is used to identify the distinct rangefeeds {descriptors,
// zones, protected_ts_records} established by the SQLWatcher. Values of this
// type are also used as indexes into the buffer's rangefeedFrontiers array.
type rangefeedKind int

const (
	zonesRangefeed rangefeedKind = iota
	descriptorsRangefeed
	protectedTimestampRangefeed

	// numRangefeeds is the number of rangefeedKinds declared above and sizes
	// the buffer's rangefeedFrontiers array. It should be listed last so that
	// iota evaluates to the count of the preceding kinds.
	numRangefeeds int = iota
)

// newBuffer returns a buffer whose underlying rangefeed buffer is created with
// the given limit, and whose per-rangefeed frontiers have all been forwarded
// to initialFrontierTS.
func newBuffer(limit int, initialFrontierTS hlc.Timestamp) *buffer {
	b := &buffer{}
	b.mu.buffer = rangefeedbuffer.New[event](limit)
	// Every rangefeed starts out at the same frontier.
	for kind := 0; kind < len(b.mu.rangefeedFrontiers); kind++ {
		b.mu.rangefeedFrontiers[kind].Forward(initialFrontierTS)
	}
	return b
}

// advance forwards the frontier timestamp tracked for the given rangefeed to
// the supplied timestamp.
func (b *buffer) advance(rangefeed rangefeedKind, timestamp hlc.Timestamp) {
	b.mu.Lock()
	defer b.mu.Unlock()

	frontier := &b.mu.rangefeedFrontiers[rangefeed]
	frontier.Forward(timestamp)
}

// add hands the given event to the underlying rangefeed buffer, returning any
// error reported by it.
func (b *buffer) add(ev event) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	if err := b.mu.buffer.Add(ev); err != nil {
		return err
	}
	return nil
}

// events is a slice of buffered events that can be ordered using the sort
// package. The ordering places descriptor updates first (grouped by descriptor
// ID), then tenant target protected timestamp updates (grouped by tenant ID),
// and finally cluster target protected timestamp updates. Within a group,
// events are ordered by timestamp.
type events []event

var _ sort.Interface = (*events)(nil)

// Len implements sort.Interface.
func (s events) Len() int { return len(s) }

// Swap implements sort.Interface.
func (s events) Swap(i, j int) { s[j], s[i] = s[i], s[j] }

// Less implements sort.Interface. See the comment on events for the ordering
// it defines.
func (s events) Less(i, j int) bool {
	lhs, rhs := s[i], s[j]

	lhsIsDesc := lhs.update.IsDescriptorUpdate()
	rhsIsDesc := rhs.update.IsDescriptorUpdate()
	switch {
	case lhsIsDesc && rhsIsDesc:
		// Two descriptor updates: order by ID, breaking ties on timestamp.
		lhsDesc := lhs.update.GetDescriptorUpdate()
		rhsDesc := rhs.update.GetDescriptorUpdate()
		if lhsDesc.ID != rhsDesc.ID {
			return lhsDesc.ID < rhsDesc.ID
		}
		return lhs.timestamp.Less(rhs.timestamp)
	case lhsIsDesc:
		// Descriptor updates sort before protected timestamp updates.
		return true
	case rhsIsDesc:
		return false
	}

	// Neither side is a descriptor update, so both are protected timestamp
	// updates.
	lhsPTS := lhs.update.GetProtectedTimestampUpdate()
	rhsPTS := rhs.update.GetProtectedTimestampUpdate()
	lhsIsTenant := lhsPTS.IsTenantsUpdate()
	rhsIsTenant := rhsPTS.IsTenantsUpdate()
	switch {
	case lhsIsTenant && rhsIsTenant:
		// Two tenant targets: order by tenant ID, breaking ties on timestamp.
		if lhsPTS.TenantTarget != rhsPTS.TenantTarget {
			return lhsPTS.TenantTarget.ToUint64() < rhsPTS.TenantTarget.ToUint64()
		}
		return lhs.timestamp.Less(rhs.timestamp)
	case lhsIsTenant:
		// Tenant targets sort before cluster targets.
		return true
	case rhsIsTenant:
		return false
	default:
		// Two cluster targets: order on timestamp alone.
		return lhs.timestamp.Less(rhs.timestamp)
	}
}

// flushEvents computes the combined frontier timestamp of the buffer, i.e. the
// minimum of the frontier timestamps of all rangefeeds, and flushes the
// underlying rangefeed buffer at that timestamp. It returns the flushed events
// along with the combined frontier timestamp.
func (b *buffer) flushEvents(
	ctx context.Context,
) (updates []event, combinedFrontierTS hlc.Timestamp) {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Start from the maximum timestamp and ratchet it down to the smallest
	// frontier across all rangefeeds.
	combinedFrontierTS = hlc.MaxTimestamp
	for i := range b.mu.rangefeedFrontiers {
		combinedFrontierTS.Backward(b.mu.rangefeedFrontiers[i])
	}

	updates = b.mu.buffer.Flush(ctx, combinedFrontierTS)
	return updates, combinedFrontierTS
}

// flush computes the combined frontier timestamp of the buffer and returns a
// deduplicated list of spanconfig.SQLUpdates for the events flushed from the
// buffer at that timestamp. The combined frontier timestamp is also returned.
//
// Descriptor updates are deduplicated per descriptor ID; the descriptor types
// of duplicate updates are merged using combine, and an error is returned if
// they cannot be combined. Protected timestamp updates are deduplicated per
// tenant target, and all cluster target updates collapse into a single update.
// In the returned slice, descriptor updates precede protected timestamp
// updates.
func (b *buffer) flush(
	ctx context.Context,
) (sqlUpdates []spanconfig.SQLUpdate, _ hlc.Timestamp, _ error) {
	bufferedEvents, combinedFrontierTS := b.flushEvents(ctx)
	evs := make(events, 0, len(bufferedEvents))
	for i, ev := range bufferedEvents {
		if i != 0 {
			// Sanity check that the flushed events arrive in timestamp order.
			prevEv := bufferedEvents[i-1]
			if !prevEv.Timestamp().LessEq(ev.Timestamp()) {
				log.Fatalf(ctx, "expected events to be sorted by timestamp but found %+v after %+v",
					ev, prevEv)
			}
		}
		evs = append(evs, ev)
	}
	// Nil out the underlying slice since we have copied over the events.
	bufferedEvents = nil

	// The events slice can contain a SQLUpdate that applies to a descriptor, or a
	// SQLUpdate that applies to a protected timestamp record. We sort the
	// protected timestamp SQLUpdates to the end of the slice, so that we can
	// perform the subsequent deduplication of SQLUpdates. We do this instead of
	// pre-processing events to separate the two kinds of updates, to save on the
	// allocation of an additional `events` slice.
	sort.Sort(evs)

	// Find the index before which all events are SQLUpdates on descriptors.
	descriptorUpdatesIdx := sort.Search(len(evs), func(i int) bool {
		update := evs[i].update
		return update.IsProtectedTimestampUpdate()
	})

	// Deduplicate the SQLUpdates emitted by the buffer that apply to descriptors.
	for i, ev := range evs[:descriptorUpdatesIdx] {
		update := ev.update
		descriptorUpdate := update.GetDescriptorUpdate()
		if i == 0 {
			sqlUpdates = append(sqlUpdates, update)
			continue
		}

		prevUpdate := evs[i-1].update
		if prevUpdate.GetDescriptorUpdate().ID != descriptorUpdate.ID {
			sqlUpdates = append(sqlUpdates, update)
			continue
		}

		// Same descriptor ID as the previously emitted update: fold this
		// update's descriptor type into the one already emitted.
		prevDescriptorSQLUpdate := sqlUpdates[len(sqlUpdates)-1].GetDescriptorUpdate()
		descType, err := combine(prevDescriptorSQLUpdate.Type,
			descriptorUpdate.Type)
		if err != nil {
			return nil, hlc.Timestamp{}, err
		}
		sqlUpdates[len(sqlUpdates)-1] = spanconfig.MakeDescriptorSQLUpdate(
			prevDescriptorSQLUpdate.ID, descType)
	}

	// Truncate the slice to only include ProtectedTimestampUpdates now that we
	// have copied over all de-duped DescriptorUpdates.
	evs = evs[descriptorUpdatesIdx:]

	// Deduplicate the SQLUpdates emitted by the buffer that apply to protected
	// timestamps.
	for i, ev := range evs {
		update := ev.update
		curUpdate := update.GetProtectedTimestampUpdate()
		if i == 0 {
			sqlUpdates = append(sqlUpdates, update)
			continue
		}

		prevUpdate := evs[i-1].update.GetProtectedTimestampUpdate()
		// If the previous buffered event, and the current event are both tenant
		// target ProtectedTimestampUpdates for the same tenant, the current event
		// is a duplicate and is ignored. Updates for distinct tenants fall
		// through to the single append at the bottom of the loop so that each
		// tenant is emitted exactly once.
		if prevUpdate.IsTenantsUpdate() && curUpdate.IsTenantsUpdate() &&
			prevUpdate.TenantTarget == curUpdate.TenantTarget {
			continue
		}

		// If the previous buffered event, and the current event are both cluster
		// target ProtectedTimestampUpdates we can ignore duplicates.
		if prevUpdate.IsClusterUpdate() && curUpdate.IsClusterUpdate() {
			continue
		}

		// The previous buffered event and the current event target different
		// tenants, or are different kinds of ProtectedTimestampUpdates; there is
		// no deduplication possible.
		sqlUpdates = append(sqlUpdates, update)
	}

	return sqlUpdates, combinedFrontierTS, nil
}

// combine merges two catalog.DescriptorTypes into one, as follows:
//   - Identical types combine to that type.
//   - catalog.Any combined with any other type yields the other type.
//   - Two different types, neither of which is catalog.Any, cannot be
//     combined; catalog.Any is returned alongside a mismatched descriptor
//     types error.
func combine(d1, d2 catalog.DescriptorType) (catalog.DescriptorType, error) {
	switch {
	case d1 == d2:
		return d1, nil
	case d1 == catalog.Any:
		return d2, nil
	case d2 == catalog.Any:
		return d1, nil
	default:
		return catalog.Any, spanconfig.NewMismatchedDescriptorTypesError(d1, d2)
	}
}
